from torch.utils.data._utils.collate import default_collate


class MultiScale(object):
    """Collate function that groups a batch by image size before collating.

    Samples whose image values report the same ``.size()`` are collated
    together with ``default_collate``; samples of different sizes are kept
    in separate groups (tensors of differing shapes cannot be stacked).

    Args:
        img_key: Key under which each sample dict stores its image. Only
            ``.size()`` is called on that value. Defaults to ``"img"``.
    """

    def __init__(self, img_key="img"):
        self.img_key = img_key

    def __call__(self, batch):
        """Group ``batch`` by image size and collate each group.

        Args:
            batch: Iterable of sample dicts, each containing ``self.img_key``.

        Returns:
            dict mapping each image-size tuple to the ``default_collate``
            result for the samples of that size. Keys appear in order of
            first occurrence in ``batch``; an empty batch yields ``{}``.
        """
        groups = {}
        for item in batch:
            # tuple() yields a plain hashable key from the size object.
            size_key = tuple(item[self.img_key].size())
            # Append in place rather than list concatenation, which would
            # copy the group on every insertion.
            groups.setdefault(size_key, []).append(item)
        return {size: default_collate(items) for size, items in groups.items()}

